library(readstata13) # package to read Stata 13 dta files
library(tidyverse)
library(lubridate) # Working with dates
library(rstan)
# Run-time configuration ----
# Use every available core for rstan. parallel::detectCores() may return NA
# when the core count cannot be determined; fall back to one core rather than
# setting mc.cores to NA.
options(mc.cores = max(1L, parallel::detectCores(), na.rm = TRUE))
# Let rstan write compiled models to disk so unchanged models are not
# recompiled on the next run.
rstan_options(auto_write = TRUE)
# Fix R's random number generator so repeated runs are reproducible.
set.seed(1)

# Inputs ----
# Command line: <vb_fit.rds> [stan_data.rds]
#   1. .rds file holding a list whose `vb_fit` element is the VB stanfit used
#      for starting values (and for lambda) in the MAP fit.
#   2. Optional Stan data .rds; defaults to the data file this script has
#      always used.
args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 1L) {
  stop("usage: <vb_fit.rds> [stan_data.rds]", call. = FALSE)
}

stan_data_path <- if (length(args) >= 2L) {
  args[2]
} else {
  "Results/stan_data_for_fit_age_sex_2021_01_18.rds"
}
stan_data <- readRDS(stan_data_path)

# `[[` (not `$`) so a similarly named element is never partially matched.
vb_fit <- readRDS(args[1])[["vb_fit"]]
if (!inherits(vb_fit, "stanfit")) {
  stop("'", args[1], "' does not contain a `vb_fit` stanfit object.",
       call. = FALSE)
}

# Output path: the input path with every "vb" replaced by "MAP". If nothing
# was replaced, saving would overwrite the VB fit itself, so stop now, before
# the expensive compilation and optimisation.
filename <- gsub("vb", "MAP", args[1])
if (identical(filename, args[1])) {
  stop("Cannot derive an output file name from '", args[1],
       "': it does not contain \"vb\".", call. = FALSE)
}

# VB posterior medians ----
# Named list with one entry per quantity of the VB fit (`lp__` excluded),
# holding the median of the VB draws. It is used as `init` for
# rstan::optimizing(), so each entry keeps its parameter's shape: scalars stay
# plain numbers, and arrays get their declared dimensions back. Without dim, a
# size-1 array would be passed as a scalar and rejected by rstan.
median_vb_estimates <- local({
  par_names <- setdiff(vb_fit@model_pars, "lp__")
  medians <- lapply(par_names, function(parname) {
    med <- unname(
      summary(vb_fit, pars = parname, probs = 0.5)$summary[, "50%"]
    )
    dims <- vb_fit@par_dims[[parname]]
    if (length(dims) == 0L) {
      med
    } else {
      # assumes summary rows follow rstan's column-major flat-name order,
      # which is the order array() fills in
      array(med, dim = dims)
    }
  })
  setNames(medians, par_names)
})

# Stan model ----
# Poisson counts Ymtas[i] with rate
#   Nmtas[i] * (exp(log_h0m[municipality[i]] + beta_age * age_class_scaled[i]
#                   + beta_sex * sex[i])
#               + covid[i] * nu_m[municipality[i]] * rho_as[i]),
# where log_h0m is non-centred (log_h0m_raw * sigma_h + mu_h). The rate of the
# exponential prior on nu_m, lambda, is data here rather than a parameter.
# `municipality` indexes Stan arrays of length Nmunicipalities (1-based), so it
# is bounded to [1, Nmunicipalities]. A bad index is then rejected when the
# data are read, instead of surfacing as an index error during optimisation.
Poisson_model_age_sex_effect_nu_shrunk_MAP <- "
data {
  int<lower=0> N;         // number of data points
  int<lower=0> Nmunicipalities;        
  int<lower=1, upper=Nmunicipalities> municipality[N];
  int<lower=0> Ymtas[N];
  int<lower=0> Nmtas[N];
  int<lower=0> covid[N]; //
  int<lower=0> Nage_class;         // 
  real age_class_scaled[N];
  int<lower=0> sex[N]; //
  real<lower=0, upper=1> rho_as[N]; // 
  real<lower=0> lambda;
}
parameters {
  real log_h0m_raw[Nmunicipalities];  
  real<lower=0, upper=1> nu_m[Nmunicipalities];
  real mu_h;
  real beta_age;
  real beta_sex;
  real<lower=0> sigma_h;
}
  transformed parameters {
    real log_h0m[Nmunicipalities];
    
    for (m in 1:Nmunicipalities){
      log_h0m[m] = log_h0m_raw[m]*sigma_h + mu_h;
    }
  }
model {
  real lambda_eff[N];
  mu_h ~ normal(-5, 5);
  sigma_h ~ normal(1, 20);
  beta_age ~ normal(0, 2);
  beta_sex ~ normal(0, 2);
  
  nu_m ~ exponential(lambda);
  log_h0m_raw ~ std_normal();
  
  for (i in 1:N) {
      lambda_eff[i] = Nmtas[i]*(exp(log_h0m[municipality[i]] + beta_age*age_class_scaled[i] + beta_sex*sex[i]) + covid[i]*nu_m[municipality[i]]*rho_as[i]);
    }
//  print(lambda_eff);
//  print(Ymtas);

  Ymtas ~ poisson(lambda_eff);
}
"
# Write the model to a .stan file so it can be compiled from (and inspected
# as) a file.
cat(file = "Scripts/Poisson_model_age_sex_effect_nu_shrunk_MAP.stan", Poisson_model_age_sex_effect_nu_shrunk_MAP)

# MAP fit ----
# Compile the model from its .stan file (top-level system.time() reports how
# long compilation took).
system.time(m <- stan_model(file = "Scripts/Poisson_model_age_sex_effect_nu_shrunk_MAP.stan"))

# lambda is data in this model; fix it at the VB posterior median. `[[` avoids
# `$`'s partial matching, which could silently pick up another quantity.
vb_lambda <- median_vb_estimates[["lambda"]]
if (is.null(vb_lambda)) {
  stop("The VB fit has no `lambda` quantity to fix the prior rate with.",
       call. = FALSE)
}
stan_data$lambda <- vb_lambda

# Maximise the log posterior density (priors included, so despite the name
# this is a MAP estimate), starting from the VB posterior medians.
MLE_fit <- optimizing(
  object = m,
  data = stan_data,
  verbose = TRUE,
  init = median_vb_estimates,
  iter = 100000
)

# optimizing() signals success with return_code 0. The result is still saved,
# but a failed or unconverged run is made visible.
if (!isTRUE(MLE_fit$return_code == 0)) {
  warning("optimizing() returned code ", MLE_fit$return_code,
          "; the MAP estimate may not have converged.", call. = FALSE)
}

print(paste("saving to", filename))

saveRDS(MLE_fit, file = filename)
